{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing a Pre-Trained Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we know how to train an LSTM network, let's see try seeing the results on our own input text!\n",
    "\n",
    "As we showed at the end of the Oriole notebook, we can load in a pretrained LSTM network using Tensorflow's Saver object. Before creating that object, we first have to create our Tensorflow graph. \n",
    "\n",
    "First, we declare some of our hyperparamters. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "numDimensions = 300\n",
    "maxSeqLength = 250\n",
    "batchSize = 24\n",
    "lstmUnits = 64\n",
    "numClasses = 2\n",
    "iterations = 100000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we'll load in our data structures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "wordsList = np.load('wordsList.npy').tolist()\n",
    "wordsList = [word.decode('UTF-8') for word in wordsList] #Encode words as UTF-8\n",
    "wordVectors = np.load('wordVectors.npy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll create our graph. This is the same code as we saw in the main Oriole notebook. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "tf.reset_default_graph()\n",
    "\n",
    "labels = tf.placeholder(tf.float32, [batchSize, numClasses])\n",
    "input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength])\n",
    "\n",
    "data = tf.Variable(tf.zeros([batchSize, maxSeqLength, numDimensions]),dtype=tf.float32)\n",
    "data = tf.nn.embedding_lookup(wordVectors,input_data)\n",
    "\n",
    "lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)\n",
    "lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.25)\n",
    "value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)\n",
    "\n",
    "weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))\n",
    "bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))\n",
    "value = tf.transpose(value, [1, 0, 2])\n",
    "last = tf.gather(value, int(value.get_shape()[0]) - 1)\n",
    "prediction = (tf.matmul(last, weight) + bias)\n",
    "\n",
    "correctPred = tf.equal(tf.argmax(prediction,1), tf.argmax(labels,1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correctPred, tf.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we load in the network. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "saver = tf.train.Saver()\n",
    "saver.restore(sess, tf.train.latest_checkpoint('models'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we input our own text, let's first define a couple of functions. The first is a function to make sure the sentence is in the proper format, and the second is a function that obtains the word vectors for each of the words in a given sentence. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Removes punctuation, parentheses, question marks, etc., and leaves only alphanumeric characters\n",
    "import re\n",
    "strip_special_chars = re.compile(\"[^A-Za-z0-9 ]+\")\n",
    "\n",
    "def cleanSentences(string):\n",
    "    string = string.lower().replace(\"<br />\", \" \")\n",
    "    return re.sub(strip_special_chars, \"\", string.lower())\n",
    "\n",
    "def getSentenceMatrix(sentence):\n",
    "    arr = np.zeros([batchSize, maxSeqLength])\n",
    "    sentenceMatrix = np.zeros([batchSize,maxSeqLength], dtype='int32')\n",
    "    cleanedSentence = cleanSentences(sentence)\n",
    "    split = cleanedSentence.split()\n",
    "    for indexCounter,word in enumerate(split):\n",
    "        try:\n",
    "            sentenceMatrix[0,indexCounter] = wordsList.index(word)\n",
    "        except ValueError:\n",
    "            sentenceMatrix[0,indexCounter] = 399999 #Vector for unkown words\n",
    "    return sentenceMatrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can create our input text. Try modifying the value for inputText, and see how the outputs change!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "inputText = \"That movie was terrible.\"\n",
    "inputMatrix = getSentenceMatrix(inputText)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative Sentiment\n"
     ]
    }
   ],
   "source": [
    "predictedSentiment = sess.run(prediction, {input_data: inputMatrix})[0]\n",
    "# predictedSentiment[0] represents output score for positive sentiment\n",
    "# predictedSentiment[1] represents output score for negative sentiment\n",
    "\n",
    "if (predictedSentiment[0] > predictedSentiment[1]):\n",
    "    print \"Positive Sentiment\"\n",
    "else:\n",
    "    print \"Negative Sentiment\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "secondInputText = \"That movie was the best one I have ever seen.\"\n",
    "secondInputMatrix = getSentenceMatrix(secondInputText)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positive Sentiment\n"
     ]
    }
   ],
   "source": [
    "predictedSentiment = sess.run(prediction, {input_data: secondInputMatrix})[0]\n",
    "if (predictedSentiment[0] > predictedSentiment[1]):\n",
    "    print \"Positive Sentiment\"\n",
    "else:\n",
    "    print \"Negative Sentiment\""
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
